{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "import torch\n",
    "\n",
    "resnet101 = models.resnet101(pretrained=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./2.png\" alt=\"图片替换文本\" width=\"500\" height=\"313\" align=\"bottom\" />"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RRB(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super(RRB, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.bn = nn.BatchNorm2d(out_channels)\n",
    "        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        res = self.conv2(x)\n",
    "        res = self.bn(res)\n",
    "        res = self.relu(res)\n",
    "        res = self.conv3(res)\n",
    "        return self.relu(x + res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./3.png\" alt=\"图片替换文本\" width=\"500\" height=\"313\" align=\"bottom\" />"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CAB(nn.Module):\n",
    "\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super(CAB, self).__init__()\n",
    "        self.global_pooling = nn.AdaptiveAvgPool2d(1)\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)\n",
    "        self.sigmod = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x1, x2 = x  # high, low\n",
    "        x = torch.cat([x1, x2], dim=1)\n",
    "        x = self.global_pooling(x)\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.sigmod(x)\n",
    "        x2 = x * x2\n",
    "        res = x2 + x1\n",
    "        return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./1.png\" alt=\"图片替换文本\" width=\"500\" height=\"313\" align=\"bottom\" />"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DFN(nn.Module):\n",
    "    def __init__(self, num_class=21):\n",
    "        super(DFN, self).__init__()\n",
    "        self.num_class = num_class\n",
    "        self.layer0 = nn.Sequential(resnet101.conv1, resnet101.bn1, resnet101.relu)\n",
    "        self.layer1 = nn.Sequential(resnet101.maxpool, resnet101.layer1)\n",
    "        self.layer2 = resnet101.layer2\n",
    "        self.layer3 = resnet101.layer3\n",
    "        self.layer4 = resnet101.layer4\n",
    "\n",
    "        # this is for smooth network\n",
    "        self.out_conv = nn.Conv2d(2048, self.num_class, kernel_size=1, stride=1)\n",
    "        self.global_pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.cab1 = CAB(self.num_class*2, self.num_class)\n",
    "        self.cab2 = CAB(self.num_class*2, self.num_class)\n",
    "        self.cab3 = CAB(self.num_class*2, self.num_class)\n",
    "        self.cab4 = CAB(self.num_class*2, self.num_class)\n",
    "\n",
    "        self.rrb_d_1 = RRB(256, self.num_class)\n",
    "        self.rrb_d_2 = RRB(512, self.num_class)\n",
    "        self.rrb_d_3 = RRB(1024, self.num_class)\n",
    "        self.rrb_d_4 = RRB(2048, self.num_class)\n",
    "\n",
    "        self.upsample = nn.Upsample(scale_factor=2,mode=\"bilinear\")\n",
    "        self.upsample_4 = nn.Upsample(scale_factor=4, mode=\"bilinear\")\n",
    "        self.upsample_8 = nn.Upsample(scale_factor=8, mode=\"bilinear\")\n",
    "\n",
    "        self.rrb_u_4 = RRB(self.num_class,self.num_class)\n",
    "        self.rrb_u_3 = RRB(self.num_class,self.num_class)\n",
    "        self.rrb_u_2 = RRB(self.num_class,self.num_class)\n",
    "        self.rrb_u_1 = RRB(self.num_class,self.num_class)\n",
    "\n",
    "        # this is for boarder net work\n",
    "        self.rrb_db_1 = RRB(256, self.num_class)\n",
    "        self.rrb_db_2 = RRB(512, self.num_class)\n",
    "        self.rrb_db_3 = RRB(1024, self.num_class)\n",
    "        self.rrb_db_4 = RRB(2048, self.num_class)\n",
    "\n",
    "        self.rrb_trans_1 = RRB(self.num_class,self.num_class)\n",
    "        self.rrb_trans_2 = RRB(self.num_class,self.num_class)\n",
    "        self.rrb_trans_3 = RRB(self.num_class,self.num_class)\n",
    "\n",
    "    def forward(self, x):\n",
    "        f0 = self.layer0(x)  # 256, 256, 64\n",
    "        f1 = self.layer1(f0)  # 128, 128, 256\n",
    "        f2 = self.layer2(f1)  # 64, 64, 512\n",
    "        f3 = self.layer3(f2)  # 32, 32, 1024\n",
    "        f4 = self.layer4(f3)  # 16, 16, 2048\n",
    "\n",
    "        # for border network\n",
    "        res1 = self.rrb_db_1(f1)\n",
    "        res1 = self.rrb_trans_1(res1 + self.upsample(self.rrb_db_2(f2)))\n",
    "        res1 = self.rrb_trans_2(res1 + self.upsample_4(self.rrb_db_3(f3)))\n",
    "        res1 = self.rrb_trans_3(res1 + self.upsample_8(self.rrb_db_4(f4)))      # 128, 128, 21\n",
    "\n",
    "        # for smooth network\n",
    "        res2 = self.out_conv(f4)    # 16, 16, 21\n",
    "        res2 = self.global_pool(res2)  #\n",
    "        res2 = nn.Upsample(size=f4.size()[2:],mode=\"nearest\")(res2)     # 16, 16, 21\n",
    "\n",
    "        f4 = self.rrb_d_4(f4)\n",
    "        res2 = self.cab4([res2, f4])\n",
    "        res2 = self.rrb_u_4(res2)\n",
    "\n",
    "        f3 = self.rrb_d_3(f3)\n",
    "        res2 = self.cab3([self.upsample(res2), f3])\n",
    "        res2 =self.rrb_u_3(res2)\n",
    "\n",
    "        f2 = self.rrb_d_2(f2)\n",
    "        res2 = self.cab2([self.upsample(res2), f2])\n",
    "        res2 =self.rrb_u_2(res2)\n",
    "\n",
    "        f1 = self.rrb_d_1(f1)\n",
    "        res2 = self.cab1([self.upsample(res2), f1])\n",
    "        res2 = self.rrb_u_1(res2)\n",
    "\n",
    "        return res1, res2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\nanfe\\anaconda3\\lib\\site-packages\\torch\\nn\\functional.py:2539: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
      "  \"See the documentation of nn.Upsample for details.\".format(mode))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 21, 128, 128]) torch.Size([1, 21, 128, 128])\n"
     ]
    }
   ],
   "source": [
    "if __name__ == '__main__':\n",
    "    import torch as t\n",
    "    model = DFN(21)\n",
    "    model.eval()\n",
    "    image = t.randn(1, 3, 512, 512)\n",
    "    res1, res2 = model(image)\n",
    "    print(res1.size(), res2.size())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
